{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d58dece3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>item_id</th>\n",
       "      <th>score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1615</td>\n",
       "      <td>12977</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>782</td>\n",
       "      <td>13124</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1084</td>\n",
       "      <td>16475</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>593</td>\n",
       "      <td>8690</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>127</td>\n",
       "      <td>14225</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  item_id  score\n",
       "0     1615    12977      1\n",
       "1      782    13124      0\n",
       "2     1084    16475      0\n",
       "3      593     8690      0\n",
       "4      127    14225      1"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the train/valid/test interaction logs and the item table from files\n",
    "import ast\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "train_data = pd.read_csv(\"../../data/a0910/train.csv\")\n",
    "valid_data = pd.read_csv(\"../../data/a0910/valid.csv\")\n",
    "test_data = pd.read_csv(\"../../data/a0910/test.csv\")\n",
    "df_item = pd.read_csv(\"../../data/a0910/item.csv\")\n",
    "\n",
    "# Map each item_id to its de-duplicated knowledge codes and collect every code seen.\n",
    "# The knowledge_code column is parsed from text; ast.literal_eval accepts Python\n",
    "# literals only, so file content can never be executed as code (unlike eval).\n",
    "item2knowledge = {}\n",
    "knowledge_set = set()\n",
    "for item_id, raw_codes in zip(df_item['item_id'], df_item['knowledge_code']):\n",
    "    knowledge_codes = list(set(ast.literal_eval(raw_codes)))\n",
    "    item2knowledge[item_id] = knowledge_codes\n",
    "    knowledge_set.update(knowledge_codes)\n",
    "\n",
    "train_data.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f050269d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(186049, 25606, 55760)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of interaction records in the train / valid / test splits\n",
    "tuple(len(split) for split in (train_data, valid_data, test_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f5404492",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4128, 17746, 123)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Get basic data info for model initialization\n",
    "import numpy as np\n",
    "\n",
    "all_splits = [train_data, valid_data, test_data]\n",
    "\n",
    "# Scan every split, not just train: a user or item that first appears in valid/test\n",
    "# must still be a valid index for a model sized from these counts.\n",
    "user_n = np.max([np.max(split['user_id']) for split in all_splits])\n",
    "item_n = np.max([np.max(split['item_id']) for split in all_splits])\n",
    "knowledge_n = np.max(list(knowledge_set))\n",
    "\n",
    "user_n, item_n, knowledge_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b24caf23",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<torch.utils.data.dataloader.DataLoader at 0x1e80529fc40>,\n",
       " <torch.utils.data.dataloader.DataLoader at 0x1e80784faf0>,\n",
       " <torch.utils.data.dataloader.DataLoader at 0x1e80784f8b0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Transform data to torch Dataloader (i.e., batchify)\n",
    "# batch_size is set to 32\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "\n",
    "def transform(user, item, item2knowledge, score, batch_size, shuffle=True):\n",
    "    \"\"\"Pack one data split into a batched DataLoader.\n",
    "\n",
    "    Each sample is a 4-tuple: (0-based user index, 0-based item index,\n",
    "    multi-hot knowledge vector of width ``knowledge_n``, score as float32).\n",
    "\n",
    "    Args:\n",
    "        user: 1-based user ids, one per interaction.\n",
    "        item: 1-based item ids, aligned with ``user``.\n",
    "        item2knowledge: dict mapping item id -> list of 1-based knowledge codes.\n",
    "        score: response scores, aligned with ``user``.\n",
    "        batch_size: number of samples per batch.\n",
    "        shuffle: whether the loader shuffles samples (default True, the previous behaviour).\n",
    "\n",
    "    Returns:\n",
    "        A torch DataLoader over the packed TensorDataset.\n",
    "\n",
    "    Note:\n",
    "        Reads the notebook-level ``knowledge_n`` and ``np`` defined in an earlier cell.\n",
    "    \"\"\"\n",
    "    # Work on plain positional arrays: item[idx] on a pandas Series is label-based\n",
    "    # and breaks (or silently misaligns) once the index is not 0..n-1.\n",
    "    user_ids = np.asarray(user, dtype=np.int64)\n",
    "    item_ids = np.asarray(item, dtype=np.int64)\n",
    "    scores = np.asarray(score, dtype=np.float32)\n",
    "\n",
    "    knowledge_emb = torch.zeros((len(item_ids), knowledge_n))\n",
    "    for row, item_id in enumerate(item_ids):\n",
    "        # Explicit int64 keeps an empty code list a valid (no-op) index.\n",
    "        codes = np.asarray(item2knowledge[item_id], dtype=np.int64)\n",
    "        # Knowledge codes are shifted by -1 to become 0-based column positions.\n",
    "        knowledge_emb[row, torch.from_numpy(codes - 1)] = 1.0\n",
    "\n",
    "    data_set = TensorDataset(\n",
    "        torch.tensor(user_ids, dtype=torch.int64) - 1,  # (1, user_n) to (0, user_n-1)\n",
    "        torch.tensor(item_ids, dtype=torch.int64) - 1,  # (1, item_n) to (0, item_n-1)\n",
    "        knowledge_emb,\n",
    "        torch.tensor(scores, dtype=torch.float32)\n",
    "    )\n",
    "    return DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "\n",
    "# Only the training split is shuffled; valid/test keep file order so that\n",
    "# evaluation batches are the same on every run.\n",
    "train_set, valid_set, test_set = [\n",
    "    transform(data[\"user_id\"], data[\"item_id\"], item2knowledge, data[\"score\"], batch_size, shuffle=is_train)\n",
    "    for data, is_train in [(train_data, True), (valid_data, False), (test_data, False)]\n",
    "]\n",
    "\n",
    "train_set, valid_set, test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "70265e93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let INFO-level records through the root logger\n",
    "import logging\n",
    "\n",
    "root_logger = logging.getLogger()\n",
    "root_logger.setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1acfaf6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from EduCDM import KaNCD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fb70ab1d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:traing... (lr=0.002)\n",
      "Epoch 0: 100%|████████████████████████████████████████████████████████████████████| 5815/5815 [00:47<00:00, 123.17it/s]\n",
      "INFO:root:[Epoch 0] average loss: 0.569911\n",
      "INFO:root:eval ... \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0] average loss: 0.569911\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|███████████████████████████████████████████████████████████████████| 801/801 [00:02<00:00, 270.70it/s]\n",
      "INFO:root:[Epoch 0] auc: 0.763524, acc: 0.734476\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0] auc: 0.763524, acc: 0.734476\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|████████████████████████████████████████████████████████████████████| 5815/5815 [00:49<00:00, 117.69it/s]\n",
      "INFO:root:[Epoch 1] average loss: 0.492857\n",
      "INFO:root:eval ... \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 1] average loss: 0.492857\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|███████████████████████████████████████████████████████████████████| 801/801 [00:02<00:00, 353.55it/s]\n",
      "INFO:root:[Epoch 1] auc: 0.766779, acc: 0.734984\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 1] auc: 0.766779, acc: 0.734984\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|████████████████████████████████████████████████████████████████████| 5815/5815 [00:46<00:00, 125.03it/s]\n",
      "INFO:root:[Epoch 2] average loss: 0.463993\n",
      "INFO:root:eval ... \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 2] average loss: 0.463993\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|███████████████████████████████████████████████████████████████████| 801/801 [00:02<00:00, 314.35it/s]\n",
      "INFO:root:[Epoch 2] auc: 0.766093, acc: 0.731352\n",
      "INFO:root:save parameters to kancd.snapshot\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 2] auc: 0.766093, acc: 0.731352\n"
     ]
    }
   ],
   "source": [
    "# Train on the GPU when one is available; otherwise fall back to the CPU so the\n",
    "# notebook still runs on machines without CUDA.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "cdm = KaNCD(exer_n=item_n, student_n=user_n, knowledge_n=knowledge_n, mf_type='gmf', dim=20)\n",
    "cdm.train(train_set, valid_set, epoch_n=3, device=device, lr=0.002)\n",
    "cdm.save(\"kancd.snapshot\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9bdb20ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:load parameters from kancd.snapshot\n",
      "INFO:root:eval ... \n",
      "Evaluating: 100%|█████████████████████████████████████████████████████████████████| 1743/1743 [00:05<00:00, 334.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc: 0.768029, accuracy: 0.731923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Reload the saved snapshot and score it on the test split, using the GPU only\n",
    "# when one is present so the cell also runs on CPU-only machines.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "cdm.load(\"kancd.snapshot\")\n",
    "auc, accuracy = cdm.eval(test_set, device=device)\n",
    "print(\"auc: %.6f, accuracy: %.6f\" % (auc, accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b90ae5f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
